{% extends 'base.attached.class' %}

{% block content %}
{% if is_test or to_json %}
import java.util.Arrays;

{% endif %}
/**
 * Decision-tree classifier rendered from a template.
 *
 * <p>The tree is stored as parallel arrays indexed by node id: for a node
 * {@code n}, {@code indices[n]} is the feature that is tested,
 * {@code thresholds[n]} the value it is compared against, {@code lefts[n]} and
 * {@code rights[n]} the node ids to continue with, and {@code classes[n]} the
 * per-class integer values of that node (presumably training sample counts;
 * confirm against the generator). Traversal always starts at node 0.
 */
class {{ class_name }} {

    /**
     * Threshold value that marks a node as a leaf.
     *
     * <p>NOTE(review): a genuine split threshold that is exactly equal to this
     * value would also be treated as a leaf. Confirm that the generator never
     * emits such a threshold for an inner node.
     */
    private static final double LEAF_THRESHOLD = -2;

    private final int[] lefts;
    private final int[] rights;
    private final double[] thresholds;
    private final int[] indices;
    private final int[][] classes;

    /**
     * Creates a classifier on top of the given tree arrays.
     *
     * <p>The arrays are stored by reference, not copied.
     *
     * @param lefts      node id to visit when the tested feature is less than or equal to the threshold
     * @param rights     node id to visit when the tested feature is greater than the threshold
     * @param thresholds split threshold per node; {@code -2} marks a leaf
     * @param indices    index of the tested feature per node
     * @param classes    per-class integer values per node
     */
    public {{ class_name }}(int[] lefts, int[] rights, double[] thresholds, int[] indices, int[][] classes) {
        this.lefts = lefts;
        this.rights = rights;
        this.thresholds = thresholds;
        this.indices = indices;
        this.classes = classes;
    }

    /**
     * Returns the index of the largest value; the first one wins on ties.
     *
     * @param nums values to scan
     * @return index of the first maximum, or 0 for an empty array
     */
    private static int findMax(int[] nums) {
        int best = 0;
        // Index 0 is the initial candidate, so the scan can start at 1.
        for (int i = 1; i < nums.length; i++) {
            if (nums[i] > nums[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Scales the values so that they sum to one.
     *
     * @param nums values to normalize
     * @return a new array with {@code nums[i] / sum}; when the sum is zero every
     *         entry is {@code 1.0 / nums.length} instead
     */
    private static double[] normVals(int[] nums) {
        int n = nums.length;
        double[] result = new double[n];
        double sum = 0.;
        for (int value : nums) {
            sum += value;
        }
        if (sum == 0) {
            // Nothing to scale by: fall back to a uniform distribution.
            double uniform = 1.0 / n;
            for (int i = 0; i < n; i++) {
                result[i] = uniform;
            }
        } else {
            for (int i = 0; i < n; i++) {
                result[i] = nums[i] / sum;
            }
        }
        return result;
    }

    /**
     * Walks the tree from node 0 until a leaf is reached.
     *
     * <p>The walk is iterative so that its depth is not limited by the call
     * stack. A path through a well-formed tree cannot visit more nodes than the
     * tree contains, so the loop is bounded by the node count.
     *
     * @param features feature vector to classify
     * @return node id of the leaf that the features fall into
     * @throws IllegalStateException if no leaf is reached within the node count
     */
    private int findLeaf(double[] features) {
        int node = 0;
        for (int visited = 0; visited < this.thresholds.length; visited++) {
            double threshold = this.thresholds[node];
            if (threshold == LEAF_THRESHOLD) {
                return node;
            }
            node = features[this.indices[node]] <= threshold
                    ? this.lefts[node]
                    : this.rights[node];
        }
        throw new IllegalStateException("Tree traversal did not reach a leaf.");
    }

    /**
     * Predicts the class for the given features.
     *
     * @param features feature vector to classify
     * @return index of the largest per-class value in the reached leaf
     */
    public int predict(double[] features) {
        return findMax(this.classes[findLeaf(features)]);
    }

    /**
     * Predicts the class probabilities for the given features.
     *
     * @param features feature vector to classify
     * @return normalized per-class values of the reached leaf
     */
    public double[] predictProba(double[] features) {
        return normVals(this.classes[findLeaf(features)]);
    }

    /**
     * Command-line entry point: parses one feature per argument, builds the
     * classifier from the embedded model data and prints the result.
     *
     * @param args the feature values, one per argument
     * @throws IllegalArgumentException if the number of arguments differs from the expected number of features
     */
    public static void main(String[] args) {
        int nFeatures = {{ n_features }};
        if (args.length != nFeatures) {
            throw new IllegalArgumentException("You have to pass " + nFeatures + " features.");
        }

        // Features:
        double[] features = new double[nFeatures];
        for (int i = 0; i < nFeatures; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Model data:
        {{ lefts }}
        {{ rights }}
        {{ thresholds }}
        {{ indices }}
        {{ classes }}

        // Estimator:
        {{ class_name }} clf = new {{ class_name }}(lefts, rights, thresholds, indices, classes);

        // Both output formats need the class and the probabilities.
        int prediction = clf.predict(features);
        double[] probabilities = clf.predictProba(features);

        {% if is_test or to_json %}
        // Get JSON:
        System.out.println("{\"predict\": " + prediction + ", \"predict_proba\": " + Arrays.toString(probabilities) + "}");
        {% else %}
        // Get class prediction:
        System.out.println("Predicted class: #" + prediction);

        // Get class probabilities:
        for (int i = 0; i < probabilities.length; i++) {
            System.out.println("Probability of class #" + i + " : " + probabilities[i]);
        }
        {% endif %}

    }
}
{% endblock %}